import abc
import random
from typing import Callable

from centralized_verification.agents.single_agent import SingleAgentLearner
from centralized_verification.agents.utils import default_epsilon_schedule
from centralized_verification.shields.shield import AgentResult, AgentUpdate
from centralized_verification.utils import convert_gym_space_to_q_shape, TrainingProgress


class QLearner(SingleAgentLearner, abc.ABC):
    """Abstract base for epsilon-greedy Q-learning agents.

    Subclasses supply the value representation by implementing
    ``get_greedy_action`` and ``update_q``; this class provides the
    epsilon-greedy action selection and the routing of observed transitions
    (including shield-augmented ones) into ``update_q``.
    """

    def __init__(self, obs_space, action_space, discount=0.9, alpha_index=1,
                 epsilon_scheduler: Callable[[TrainingProgress], float] = default_epsilon_schedule,
                 evaluation_epsilon: float = 0.0):
        """Store the learner configuration.

        Args:
            obs_space: Observation space; stored as-is for subclasses.
            action_space: Action space; its first Q-shape dimension (as
                returned by ``convert_gym_space_to_q_shape``) is used as the
                number of actions.
            discount: Stored on the instance; not used by this base class.
            alpha_index: Stored on the instance; not used by this base class
                (presumably a learning-rate parameter for subclasses --
                confirm against implementations).
            epsilon_scheduler: Maps a ``TrainingProgress`` to the exploration
                rate used while training.
            evaluation_epsilon: Exploration rate used when ``get_action`` is
                called with ``training_progress=None``.
        """
        self.obs_space = obs_space
        self.num_actions = convert_gym_space_to_q_shape(action_space)[0]
        self.discount = discount
        self.alpha_index = alpha_index
        self.epsilon_scheduler = epsilon_scheduler
        self.evaluation_epsilon = evaluation_epsilon

        # Seed the logged epsilon so get_log_dict() has a value before the
        # first call to get_action().
        self.log_last_eps = epsilon_scheduler(TrainingProgress(0, 0))

    @abc.abstractmethod
    def get_greedy_action(self, observation):
        """Return the greedy action for an already-transformed observation.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        # NotImplemented is a comparison sentinel, not an exception; raising
        # it would surface as a confusing TypeError instead.
        raise NotImplementedError

    def get_action(self, observation, training_progress):
        """Choose an action epsilon-greedily.

        Args:
            observation: Raw observation; it is passed through
                ``transform_obs`` before the greedy lookup.
            training_progress: Progress handed to the epsilon scheduler, or
                ``None`` to use ``evaluation_epsilon`` instead.

        Returns:
            A uniformly random action index in ``[0, num_actions - 1]`` with
            probability epsilon, otherwise the result of
            ``get_greedy_action``.
        """
        if training_progress is not None:
            epsilon = self.epsilon_scheduler(training_progress)
        else:
            epsilon = self.evaluation_epsilon
        # Remember the epsilon actually used so it can be reported in logs.
        self.log_last_eps = epsilon

        if random.random() < epsilon:
            # random.randint is inclusive on both ends.
            return random.randint(0, self.num_actions - 1)
        return self.get_greedy_action(self.transform_obs(observation))

    @abc.abstractmethod
    def update_q(self, obs, action, next_obs, rew, done, step_num, training_progress):
        """Update the value estimate from one transition.

        ``obs`` and ``next_obs`` arrive already passed through
        ``transform_obs`` when called via ``observe_transition``.
        """
        pass

    def transform_obs(self, obs):
        """Map a raw observation to the form used by the learner.

        The base implementation is the identity; subclasses may override it.
        """
        return obs

    def update_q_with_agent_update(self, obs, agent_update: AgentUpdate, next_obs, rew, done, step_num,
                                   training_progress):
        """Apply ``update_q`` using the action and modified reward of ``agent_update``."""
        self.update_q(obs, agent_update.action, next_obs, agent_update.get_modified_reward(rew), done, step_num,
                      training_progress)

    def observe_transition(self, obs, shield_result: AgentResult, next_obs, rew, done, step_num, training_progress):
        """Learn from one environment transition reported through a shield.

        Both observations are transformed once, then ``update_q`` is driven by
        ``shield_result.real_action`` and, when it is not ``None``, also by
        ``shield_result.augmented_action``.
        """
        obs = self.transform_obs(obs)
        next_obs = self.transform_obs(next_obs)
        self.update_q_with_agent_update(obs, shield_result.real_action, next_obs, rew, done, step_num,
                                        training_progress)
        if shield_result.augmented_action is not None:
            self.update_q_with_agent_update(obs, shield_result.augmented_action, next_obs, rew, done, step_num,
                                            training_progress)

    def get_log_dict(self):
        """Return loggable metrics: the most recently used epsilon."""
        return {
            "epsilon": self.log_last_eps
        }
